/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

#include <bitset>
#include <cstdint>
#include <vector>
#include <iostream>

#include "cuda/common.hpp"

#define DECLSPEC __host__ __device__

#include "generic/random.h"
#include "generic/bignum.h"

#undef DECLSPEC
#define DECLSPEC __device__

#include "cuda/basepoint_mul_fold_table_32_def.h"
#include "generic/const_def.h"
#include "cuda/common.cuh"
#include "cuda/field.cuh"
#include "generic/point.h"
#include "generic/scalar_mul.h"
#include "cuda/sha512.cuh"
#include "cuda/node_id.cuh"

#undef DECLSPEC

#include "generic/kernel/common.h"
#include "cuda/engine.hpp"
#include "cuda/util.hpp"

#define KERNEL_SPEC(NAME, BSIZE) KERNEL void NAME

namespace cuda {

namespace {

// Number of precomputed addend scalars/points; also the round count after
// which the midstate batch is regenerated (seq_ wraps back to 0).
constexpr size_t MIDSTATE_ITERATIONS = 8192;
// Default grid sizing: CUDA blocks launched per streaming multiprocessor.
constexpr size_t BLOCKS_PER_MP = 256;
// Default threads per block; clamped to the device limit and required to be
// a power of two by CudaEngine::fill_params.
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
// Default batch size for the three-phase batched field inversion kernels.
constexpr size_t DEFAULT_INV_BATCH_SIZE = 256;

// Limb initializer for the ed25519_d2 field-element constant below.
#define ED25519_D2 {649261401ul, 3956710292ul, 2189668694ul, 14685338ul, 4008956208ul, 428769522ul, 1457519847ul, 604428764ul}

// Host-side copy of the precomputed basepoint multiplication table; uploaded
// to the device via CudaGenerator's basepoint_mul_fold_table_ dual buffer.
static const AffineNielsPoint basepoint_mul_fold_table[256] = ED25519_BASEPOINT_MUL_FOLD_TABLE_32_DEF;

__constant__ FieldElement ed25519_d2 = ED25519_D2;

// Small multiples of the ed25519 group order L (and related bounds), kept in
// device constant memory for scalar normalization on the GPU.
__constant__ Bn256 l_3 = BN_L_3;
__constant__ Bn256 l_4 = BN_L_4;
__constant__ Bn256 l_5 = BN_L_5;
__constant__ Bn256 l_6 = BN_L_6;
__constant__ Bn256 l_7 = BN_L_7;
__constant__ Bn256 l_minus_one = BN_L_MINUS_ONE;
__constant__ Bn256 l_3_lb = BN_L_3_LB;
__constant__ Bn256 l_7_ub = BN_L_7_UB;

__constant__ FieldElement two = {2, 0, 0, 0, 0, 0, 0, 0};

// Host-side mirrors of the scalar constants above, used by the key
// reconstruction in CudaGeneratorCtx::skey (which runs on the CPU and
// cannot read __constant__ device memory).
namespace host {

const Bn256 addend_mask = BN_ADDEND_MASK;

const Bn256 l_3 = BN_L_3;
const Bn256 l_4 = BN_L_4;
const Bn256 l_5 = BN_L_5;
const Bn256 l_6 = BN_L_6;
const Bn256 l_7 = BN_L_7;

} // namespace host

// Debug helper: prints every limb of a field element as "<what>{l0, l1, ..., }".
__host__ __device__ inline void dbg_fe(const char* what, const FieldElement fe)
{
    printf("%s{", what);

    int i = 0;
    while(i < FE_SIZE) {
        printf("%u, ", fe[i]);
        ++i;
    }

    printf("}\n");
}

// Debug helper: dumps the four extended coordinates of an Edwards point.
__host__ __device__ inline void dbg_point(const char* what, const EdwardsPoint& p)
{
    printf("%s{\n", what);

    struct Row { const char* label; const FieldElement& value; };
    const Row rows[] = {
        {"\tX: ", p.x},
        {"\tY: ", p.y},
        {"\tZ: ", p.z},
        {"\tT: ", p.t},
    };

    for(const auto& row : rows)
        dbg_fe(row.label, row.value);

    printf("}\n");
}

// Debug helper: dumps the three components of an affine Niels point.
__host__ __device__ inline void dbg_point(const char* what, const AffineNielsPoint& p)
{
    printf("%s{\n", what);

    struct Row { const char* label; const FieldElement& value; };
    const Row rows[] = {
        {"\tY + X: ", p.y_plus_x},
        {"\tY - X: ", p.y_minus_x},
        {"\tXY2D: ", p.xy2d},
    };

    for(const auto& row : rows)
        dbg_fe(row.label, row.value);

    printf("}\n");
}

#include "generic/kernel/precompute_addends.h"
#include "generic/kernel/compute_midstate.h"
#include "generic/kernel/compute_uv.h"
#include "generic/kernel/batch_invert.h"

// One thread per key: computes t = u * v (reduced), hashes it with SHA-512
// and maps the resulting node id to a 16-byte IPv6 address.
//
//   addresses -- output buffer, 16 bytes per thread, indexed by global id.
//   us, vs    -- per-thread field elements (fe_load_global presumably
//                offsets by the thread's global id internally -- confirm).
//
// Fix: removed the unused local `gsize` (`global_size()` result was never
// read).
KERNEL
void compute_address_kernel_v1(
    uint8_t* addresses,
    const uint32_t* __restrict__ us,
    const uint32_t* __restrict__ vs
) {
    const size_t gid = global_id();

    uint32_t node_id[8];

    FieldElement t, u, v;
    fe_load_global(u, us);
    fe_load_global(v, vs);

    fe_mul(t, u, v);
    fe_reduce(t);

    sha512_of_32_byte_block(t, node_id);
    node_id_to_ipv6(node_id, addresses + 16 * gid);
}

// Result context handed back to consumers: one batch of generated IPv6
// addresses plus the scalars needed to reconstruct each secret key lazily.
class CudaGeneratorCtx final : public ::GeneratorCtx {
public:
    // Number of (address, key) entries in this batch.
    size_t batch_size() override { return batch_size_; }
    // 16-byte IPv6 address of entry i.
    const uint8_t* address(size_t i) override { return addresses_.get() + 16 * i; }
    // Reconstructs the secret scalar for entry i.
    //
    // Two consecutive entries share one midstate scalar (hence i / 2): the
    // even entry is midstate + addend, the odd entry midstate - addend.
    // Before that, a small multiple of the group order L is added, chosen so
    // that (midstate[0] + chosen_l[0]) % 8 == 0, i.e. the low three bits of
    // the sum's first limb are cleared (presumably to satisfy ed25519 scalar
    // clamping -- confirm).
    //
    // NOTE(review): returns a pointer into the shared skey_ scratch buffer;
    // the result is overwritten by the next call and this is not thread-safe.
    const uint8_t* skey(size_t i) override {
        auto& midstate = midstates_[i / 2];
        auto& addend = addends_[seq_];
        // Amount (mod 8) needed to round midstate[0] up to a multiple of 8.
        uint8_t to8 = (8 - (midstate[0] % 8)) & 0x07;

        Bn256 chosen_l;

        // Select the multiple of L whose low limb contributes exactly `to8`
        // modulo 8.
        if(to8 == host::l_3[0] % 8) {
            bn_copy(chosen_l, host::l_3);
        } else if(to8 == host::l_4[0] % 8) {
            bn_copy(chosen_l, host::l_4);
        } else if(to8 == host::l_5[0] % 8) {
            bn_copy(chosen_l, host::l_5);
        } else if(to8 == host::l_6[0] % 8) {
            bn_copy(chosen_l, host::l_6);
        } else if(to8 == host::l_7[0] % 8) {
            bn_copy(chosen_l, host::l_7);
        } else {
            throw std::runtime_error{"Unable to choose L * n"};
        }

        Bn256 sk;
        bn_add(sk, midstate, chosen_l);
        if(i % 2 == 0)
            bn_add(sk, sk, addend);
        else
            bn_sub(sk, sk, addend);

        // TODO : correct conversion
        memcpy(skey_, sk, 32);
        return skey_;
    }

    // Midstate scalars, one per pair of entries; shared with the generator.
    std::shared_ptr<const Bn256[]> midstates_;
    // Per-round addend scalars; indexed by seq_ above.
    std::shared_ptr<const Bn256[]> addends_;
    // Pinned host buffer of 16-byte addresses, one per entry.
    std::shared_ptr<uint8_t[]> addresses_;
    size_t batch_size_;
    // Round index of this batch (selects which addend was applied).
    uint64_t seq_;
    // Scratch buffer returned by skey(); see the thread-safety note there.
    uint8_t skey_[32];
};

// Generates batches of addresses and key scalars on one CUDA device.
//
// Key schedule: a fresh batch of random midstate scalars/points is produced
// once per MIDSTATE_ITERATIONS rounds (when seq_ wraps to 0). Each round then
// offsets every midstate point by precomputed addend number seq_, yielding
// two candidate keys per midstate (midstate + addend and midstate - addend);
// per-address launches therefore use the doubled grid n_blocks_2_.
class CudaGenerator final : public ::Generator
{
public:
    // NOTE(review): the make_device_ptr calls in the initializer list run
    // before the body's cudaSetDevice(params.device), so those buffers are
    // allocated on whichever device is current at construction time --
    // confirm this is intended when params.device != 0.
    CudaGenerator(const GeneratorParams& params)
    : block_size_{params.block_size}
    , n_blocks_{params.n_blocks}
    , batch_size_{block_size_ * n_blocks_}
    , inv_batch_size_{params.inv_batch_size}
    , midstate_ypx_{make_device_ptr<uint32_t[]>(FE_SIZE * batch_size_)}
    , midstate_ymx_{make_device_ptr<uint32_t[]>(FE_SIZE * batch_size_)}
    , midstate_xy_{make_device_ptr<uint32_t[]>(FE_SIZE * batch_size_)}
    , seeds_{batch_size_}
    , midstate_scalars_{batch_size_}
    , basepoint_mul_fold_table_{256}
    , addend_sum_{1}
    , midstate_iterations_buf_{1}
    , us_{make_device_ptr<uint32_t[]>(2 * FE_SIZE * batch_size_)}
    , us_inv_{make_device_ptr<uint32_t[]>(2 * FE_SIZE * batch_size_)}
    , vs_{make_device_ptr<uint32_t[]>(2 * FE_SIZE * batch_size_)}
    , ts_{make_device_ptr<uint32_t[]>(2 * FE_SIZE * batch_size_)}
    , addresses_{make_device_ptr<uint8_t[]>(16 * 2 * batch_size_)}
    , seq_buf_{make_device_ptr<uint64_t[]>(1)}
    , inv_batch_size_buf_{make_device_ptr<size_t[]>(1)}
    {
        CUDA_CHECK(cudaSetDevice(params.device));
        rng_init(&rng_, params.seed, params.seq);
        // Upload the two scalar parameters that kernels read from device memory.
        copy_to_device(inv_batch_size_buf_.get(), &inv_batch_size_, 1, stream_);
        *midstate_iterations_buf_.host() = MIDSTATE_ITERATIONS;
        midstate_iterations_buf_.copy_to_device(stream_);
    }

private:
    // Generator
    // Produces one batch: refreshes the midstates at the start of every
    // MIDSTATE_ITERATIONS-round cycle, then computes this round's addresses.
    void produce(::GeneratorCtx& raw_ctx) {
        if(seq_ == 0) init_midstate();
        compute_addresses(dynamic_cast<CudaGeneratorCtx&>(raw_ctx));
    }

private:
    // One-time setup: draws MIDSTATE_ITERATIONS random addend scalars (masked
    // with addend_mask), keeps a host copy and their running sum, and launches
    // precompute_addends_kernel to derive each addend's point coordinates
    // (y+x, y-x, xy) on the device.
    void init_addends()
    {
        auto addend_scalars = std::make_unique<Bn256[]>(MIDSTATE_ITERATIONS);
        DualBuffer<Bn256> addend_scalars_buffer{MIDSTATE_ITERATIONS};

        memcpy(basepoint_mul_fold_table_.host(), basepoint_mul_fold_table, sizeof(basepoint_mul_fold_table));
        basepoint_mul_fold_table_.copy_to_device(stream_);

        Bn256 a_sum;
        bn_zero(a_sum);

        for(auto i = 0; i < MIDSTATE_ITERATIONS; ++i) {
            Bn256 a;
            rng_fill_bytes(&rng_, a, 32);
            bn_and(a, a, host::addend_mask);
            bn_add(a_sum, a_sum, a);

            // Two copies: one published to contexts, one staged for the device.
            bn_copy(addend_scalars[i], a);
            bn_copy(addend_scalars_buffer.host()[i], a);
        }

        addends_ypx_ = make_device_ptr<uint32_t[]>(FE_SIZE * MIDSTATE_ITERATIONS);
        addends_ymx_ = make_device_ptr<uint32_t[]>(FE_SIZE * MIDSTATE_ITERATIONS);
        addends_xy_ = make_device_ptr<uint32_t[]>(FE_SIZE * MIDSTATE_ITERATIONS);

        // Shared with every batch context via compute_addresses().
        addend_scalars_shared_ = std::move(addend_scalars);
        bn_copy(*addend_sum_.host(), a_sum);

        addend_scalars_buffer.copy_to_device(stream_);
        addend_sum_.copy_to_device(stream_);

        // Ceiling division so a partial final block is still launched.
        auto midstate_blocks = MIDSTATE_ITERATIONS / block_size_ + (MIDSTATE_ITERATIONS % block_size_ != 0);
        precompute_addends_kernel<<<midstate_blocks, block_size_, 0, stream_>>>(
            addends_ypx_.get(),
            addends_ymx_.get(),
            addends_xy_.get(),
            reinterpret_cast<Bn256*>(addend_scalars_buffer.device()),
            basepoint_mul_fold_table_.device(),
            midstate_iterations_buf_.device()
        );
        // NOTE(review): addend_scalars_buffer goes out of scope here while the
        // kernel above may still be in flight on stream_ -- confirm DualBuffer
        // destruction is safe in that case.
    }

    // Regenerates the midstate batch: random seeds -> midstate scalars and
    // point coordinates on the device, including the batched inversion needed
    // for the affine form.
    void init_midstate()
    {
        if(!addends_ypx_) init_addends();

        // Invalidate the host-side snapshot; compute_addresses() re-creates
        // it after the new scalars have been copied back.
        midstate_scalars_shared_ = nullptr;

        rng_fill_bytes(&rng_, reinterpret_cast<uint8_t*>(seeds_.host()), seeds_.size() * sizeof(uint64_t));

        seeds_.copy_to_device(stream_);
        compute_midstate_kernel_0<<<n_blocks_, block_size_, 0, stream_>>>(
            midstate_ypx_.get(),
            midstate_ymx_.get(),
            us_.get(),
            reinterpret_cast<Bn256*>(midstate_scalars_.device()),
            seeds_.device(),
            // TODO : remove cast
            reinterpret_cast<uint8_t*>(addend_sum_.device()),
            basepoint_mul_fold_table_.device()
        );
        // Three-phase batched field inversion of us_ into us_inv_.
        // NOTE(review): the grid here is n_blocks_ / inv_batch_size_, which is
        // zero when inv_batch_size_ > n_blocks_; fill_params only bounds the
        // batch size by n_blocks * 2 -- confirm a zero-sized launch cannot
        // happen here.
        batch_invert_kernel_0<<<n_blocks_ / inv_batch_size_, block_size_, 0, stream_>>>(
            us_inv_.get(), us_.get(), inv_batch_size_buf_.get()
        );
        batch_invert_kernel_1<<<n_blocks_ / inv_batch_size_, block_size_, 0, stream_>>>(
            us_inv_.get(), inv_batch_size_buf_.get()
        );
        batch_invert_kernel_2<<<n_blocks_ / inv_batch_size_, block_size_, 0, stream_>>>(
            us_inv_.get(), us_.get(), inv_batch_size_buf_.get()
        );
        compute_midstate_kernel_1<<<n_blocks_, block_size_, 0, stream_>>>(
            midstate_ypx_.get(),
            midstate_ymx_.get(),
            midstate_xy_.get(),
            us_inv_.get()
        );
        // Scalars are needed on the host for key reconstruction in skey().
        midstate_scalars_.copy_to_host(stream_);
    }

    // Runs one round: offsets every midstate by addend number seq_, performs
    // the batched inversion, derives the IPv6 addresses, and fills ctx.
    void compute_addresses(CudaGeneratorCtx& ctx)
    {
        copy_to_device(seq_buf_.get(), &seq_, 1, stream_);

        // Events bracket the GPU work for the running speed statistic below.
        CudaEvent ev_start, ev_end;

        ev_start.record(stream_);

        compute_uv_kernel_0<<<n_blocks_, block_size_, 0, stream_>>>(
            ts_.get(), // zs
            us_inv_.get(), // efs
            midstate_xy_.get(),
            addends_xy_.get(),
            seq_buf_.get()
        );

        // One FieldElement of shared memory per thread.
        const size_t uv_smem_size = block_size_ * FE_SIZE * sizeof(uint32_t);

        // Doubled grid: two outputs per midstate (presumably the +addend and
        // -addend variants -- confirm against the kernel).
        compute_uv_kernel_1<<<
            n_blocks_2_,
            block_size_,
            uv_smem_size,
            stream_
        >>>(
            us_.get(),
            vs_.get(),
            ts_.get(), // zs
            us_inv_.get(), // efs
            midstate_ypx_.get(),
            midstate_ymx_.get(),
            addends_ypx_.get(),
            addends_ymx_.get(),
            seq_buf_.get()
        );

        batch_invert_kernel_0<<<n_blocks_2_ / inv_batch_size_, block_size_, 0, stream_>>>(
            us_inv_.get(), us_.get(), inv_batch_size_buf_.get()
        );
        batch_invert_kernel_1<<<n_blocks_2_ / inv_batch_size_, block_size_, 0, stream_>>>(
            us_inv_.get(), inv_batch_size_buf_.get()
        );
        batch_invert_kernel_2<<<n_blocks_2_ / inv_batch_size_, block_size_, 0, stream_>>>(
            us_inv_.get(), us_.get(), inv_batch_size_buf_.get()
        );

        compute_address_kernel_v1<<<n_blocks_2_, block_size_, 0, stream_>>>(
            addresses_.get(), us_inv_.get(), vs_.get()
        );

        ev_end.record(stream_);

        // Lazily allocate the ctx's pinned host buffer on first use.
        if(!ctx.addresses_) {
            ctx.addresses_ = make_pinned_host_ptr<uint8_t[]>(batch_size_ * 16 * 2);
        }

        ctx.addends_ = addend_scalars_shared_;
        ctx.batch_size_ = batch_size_ * 2;
        ctx.seq_ = seq_;
        memcpy_to_host(ctx.addresses_.get(), addresses_.get(), 16 * 2 * batch_size_, stream_);

        // Wrapping to 0 makes the next produce() call regenerate the midstate.
        if(++seq_ >= MIDSTATE_ITERATIONS) seq_ = 0;
        // Wait for the kernels and the async copy before publishing results.
        stream_.sync();

        // Snapshot the freshly downloaded midstate scalars once per cycle
        // (init_midstate() cleared the shared pointer).
        if(!midstate_scalars_shared_) {
            auto midstate_scalars_tmp = std::make_unique<Bn256[]>(batch_size_);
            memcpy(midstate_scalars_tmp.get(), midstate_scalars_.host(), batch_size_ * sizeof(Bn256));
            midstate_scalars_shared_ = std::move(midstate_scalars_tmp);
        }
        ctx.midstates_ = midstate_scalars_shared_;

        total_gpu_time_ += ev_start.elapsed_time_ms(ev_end);
        total_keys_generated_ += batch_size_ * 2;

        // std::cerr << "Gpu speed: " << (total_keys_generated_ / total_gpu_time_) << " KKeys/s" << std::endl;

        CUDA_CHECK(cudaPeekAtLastError());
    }

private:
    Rng rng_;
    // NOTE(review): n_blocks_ is declared before block_size_, but the ctor's
    // initializer list names block_size_ first. Members initialize in
    // declaration order, so this is still correct (batch_size_ is declared
    // after both of its inputs) -- it does, however, trigger -Wreorder.
    const size_t n_blocks_;
    const size_t block_size_;
    const size_t batch_size_;                 // n_blocks_ * block_size_
    const size_t n_blocks_2_{n_blocks_ * 2};  // per-address grid (2 keys per midstate)
    const size_t inv_batch_size_;

    CudaStream stream_;

    // Precomputed addend point coordinates, one set per round.
    DevicePtr<uint32_t[]> addends_ypx_;
    DevicePtr<uint32_t[]> addends_ymx_;
    DevicePtr<uint32_t[]> addends_xy_;
    // Midstate point coordinates, one per thread of the base grid.
    DevicePtr<uint32_t[]> midstate_ypx_;
    DevicePtr<uint32_t[]> midstate_ymx_;
    DevicePtr<uint32_t[]> midstate_xy_;

    // 16 bytes per address, two addresses per midstate.
    DevicePtr<uint8_t[]> addresses_;

    DualBuffer<uint64_t> seeds_;
    DualBuffer<Bn256> midstate_scalars_;
    DualBuffer<AffineNielsPoint> basepoint_mul_fold_table_;
    DualBuffer<Bn256> addend_sum_;
    DualBuffer<size_t> midstate_iterations_buf_;

    // Host snapshots shared with outstanding CudaGeneratorCtx objects.
    std::shared_ptr<const Bn256[]> midstate_scalars_shared_;
    std::shared_ptr<const Bn256[]> addend_scalars_shared_;

    // Scratch field-element arrays sized for the doubled grid.
    DevicePtr<uint32_t[]> us_;
    DevicePtr<uint32_t[]> us_inv_;
    DevicePtr<uint32_t[]> vs_;
    DevicePtr<uint32_t[]> ts_;
    DevicePtr<uint64_t[]> seq_buf_;
    DevicePtr<size_t[]> inv_batch_size_buf_;

    uint64_t seq_{0};                 // current round within the cycle
    double total_gpu_time_{0};        // milliseconds, accumulated from CUDA events
    uint64_t total_keys_generated_{0};
};

// Engine implementation backed by the CUDA runtime.
//
// Fixes: printf format specifiers -- props.size() and the size_t loop index
// were printed with %d (undefined behavior on LP64 platforms); now %zu.
// Also corrected the "muliprocessors" typo in the device summary.
class CudaEngine final : public Engine
{
public:
    const char* name() override { return "cuda"; }

    // Prints a one-line device count plus per-device details for every
    // visible CUDA device.
    void print_info() override {
        auto props = get_device_props();
        printf("CUDA: %zu devices\n", props.size());

        for(size_t i = 0; i < props.size(); ++i) {
            auto& prop = props[i];
            printf("Device #%zu (%s)\n", i, prop.name);
            printf("\tCompute capability:         %d.%d\n", prop.major, prop.minor);
            printf("\tTotal Global Memory:        %.2fG\n", float(prop.totalGlobalMem) / 1024 / 1024 / 1024);
            // NOTE(review): maxThreadsDim[0] is the block's maximum
            // x-dimension; the actual per-block thread limit is
            // maxThreadsPerBlock -- confirm which was intended for this label.
            printf("\tMaximum block size:         %d\n", prop.maxThreadsDim[0]);
            printf("\tClock Rate:                 %.1fMHz\n", float(prop.clockRate) / 1000);
            printf("\tNumber of multiprocessors:  %d\n", prop.multiProcessorCount);
        }
    }

    // Returns the display name of device i.
    // Throws std::runtime_error for an out-of-range index.
    std::string device_name(size_t i) override {
        auto props = get_device_props();
        if(i >= props.size())
            throw std::runtime_error{"Invalid device: " + std::to_string(i)};
        return props[i].name;
    }

    // Fills unset generator parameters with device-appropriate defaults and
    // validates the combination. Throws std::runtime_error on invalid input.
    void fill_params(GeneratorParams& params) override {
        auto props = get_device_props();
        if(params.device >= props.size())
            throw std::runtime_error{"Invalid device index: " + std::to_string(params.device)};

        auto& prop = props[params.device];

        if(!params.block_size)
            params.block_size = std::min<size_t>(DEFAULT_BLOCK_SIZE, prop.maxThreadsPerBlock);

        // popcount == 1 <=> non-zero power of two.
        if(std::bitset<sizeof(params.block_size) * 8>(params.block_size).count() != 1)
            throw std::runtime_error{"Block size must be a power of two"};

        if(!params.n_blocks)
            params.n_blocks = BLOCKS_PER_MP * prop.multiProcessorCount;

        // The inversion batch may not exceed the doubled grid it runs over.
        if(!params.inv_batch_size)
            params.inv_batch_size = std::min<size_t>(DEFAULT_INV_BATCH_SIZE, params.n_blocks * 2);

        if(params.inv_batch_size > params.n_blocks * 2)
            throw std::runtime_error{"Invert batch size is too large"};
    }

    std::unique_ptr<::Generator> make_generator(const GeneratorParams& params) override {
        return std::make_unique<CudaGenerator>(params);
    }

    std::unique_ptr<::GeneratorCtx> make_generator_context() override {
        return std::make_unique<CudaGeneratorCtx>();
    }
};

// Process-wide engine singleton, handed out by cuda::get_engine() below.
std::unique_ptr<CudaEngine> static_engine = std::make_unique<CudaEngine>();

} // anonymous namespace

// Returns the process-wide CUDA engine singleton.
Engine& get_engine()
{
    return *static_engine;
}

} // namespace cuda